/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "aclnn_faster_gelu_custom.h"
#include "acl/acl_rt.h"
#include "acl/acl.h"
#include <cmath>
#include <stdio.h>
#include <stdlib.h>
#include "../../../../../common/data_utils.h"

/**
 * Initialises the ACL runtime, selects `device` and creates one stream on it.
 *
 * @param device  device id passed to aclrtSetDevice / aclrtResetDevice.
 * @return the new stream, or nullptr on failure. On failure every step that already
 *         succeeded is undone (device reset, ACL finalised), so the caller has nothing
 *         to clean up.
 */
aclrtStream CreateStream(const int32_t device)
{
    if (aclInit(nullptr) != ACL_SUCCESS) {
        printf("acl init failed\n");
        return nullptr;
    }
    if (aclrtSetDevice(device) != ACL_SUCCESS) {
        printf("Set device failed\n");
        CHECK_ACL(aclFinalize());
        return nullptr;
    }
    aclrtStream stream = nullptr;
    if (aclrtCreateStream(&stream) != ACL_SUCCESS) {
        printf("Create stream failed\n");
        // The device was successfully set above; release it before finalising,
        // mirroring the teardown order used for a successfully created stream.
        CHECK_ACL(aclrtResetDevice(device));
        CHECK_ACL(aclFinalize());
        return nullptr;
    }
    return stream;
}

/**
 * Releases what CreateStream acquired, in reverse order: the stream, then the
 * device selection, then the ACL runtime. Failures are reported but do not stop
 * the remaining teardown steps.
 *
 * @param stream  stream returned by CreateStream.
 * @param device  the same device id that was passed to CreateStream.
 */
void DestroyStream(aclrtStream stream, const int32_t device)
{
    CHECK_ACL(aclrtDestroyStream(stream));

    const bool deviceReset = (aclrtResetDevice(device) == ACL_SUCCESS);
    if (!deviceReset) {
        printf("Reset device failed\n");
    }

    const bool aclFinalized = (aclFinalize() == ACL_SUCCESS);
    if (!aclFinalized) {
        printf("Finalize acl failed\n");
    }
}

/**
 * Releases the device buffers and tensor handles of the first `tensorCount` slots.
 *
 * Each buffer and each handle is released independently. A device buffer whose
 * aclCreateTensor call failed (tensor slot == nullptr) is therefore still freed.
 * Released slots are reset to nullptr, so calling this twice is harmless.
 *
 * @param tensors      tensor handles; nullptr entries are skipped.
 * @param devMem       device buffers backing the tensors; nullptr entries are skipped.
 * @param tensorCount  number of slots in both arrays.
 */
void DestroyTensor(aclTensor* tensors[], void* devMem[], const int32_t tensorCount)
{
    for (int32_t i = 0; i < tensorCount; i++) {
        if (devMem[i] != nullptr) {
            CHECK_ACL(aclrtFree(devMem[i]));
            devMem[i] = nullptr;
        }
        if (tensors[i] != nullptr) {
            CHECK_ACL(aclDestroyTensor(tensors[i]));
            tensors[i] = nullptr;
        }
    }
}

/**
 * Host-side description of one operator tensor used by this test driver.
 * An entry whose `dims` is nullptr yields a size of 0 from GetDataSize, and main
 * then skips allocating it (both its tensor and device-memory slots stay nullptr).
 * Member order matters: main builds these with positional brace initialisers.
 */
struct tensorInfo {
    int64_t* dims;      // shape array; passed to aclCreateTensor as both view and storage shape
    int64_t dimCnt;     // number of entries in `dims`
    aclDataType dtype;  // element data type passed to aclCreateTensor
    aclFormat fmt;      // tensor format passed to aclCreateTensor
};

int64_t GetDataSize(const struct tensorInfo* desc)
{
    if (!desc->dims) {
        return 0;
    }
    int64_t size = 1;
    for (auto i = 0; i < desc->dimCnt; i++) {
        size *= desc->dims[i];
    }
    return size * sizeof(float);
}

/**
 * Returns true when `actual` matches `expected` within `eps`, by either absolute or
 * relative error.
 *
 * Exact equality (including equal infinities) and NaN-vs-NaN count as a match.
 * A NaN on only one side, or a finite value against an infinity, is always a mismatch.
 * The previous `ae > eps && re > eps` test let those cases pass, because every
 * comparison with NaN is false.
 */
static bool IsWithinTolerance(float actual, float expected, float eps)
{
    if (actual == expected || (std::isnan(actual) && std::isnan(expected))) {
        return true;
    }
    const float absErr = std::fabs(actual - expected);
    const float relErr = absErr / std::fabs(expected);
    return absErr <= eps || relErr <= eps;
}

/**
 * Compares `outSize` bytes of float output against ../output/golden.bin.
 *
 * Every mismatching element is printed. The comparison fails if the golden file
 * cannot be read or its size differs from `outSize`. Without that size check, a short
 * golden file would leave part of the comparison buffer uninitialised.
 *
 * @param outputData  host buffer holding the operator output, read as float.
 * @param outSize     size of `outputData` in bytes.
 * @return true when every element matches within tolerance.
 */
static bool CompareResult(const void* outputData, const int64_t outSize)
{
    if (outputData == nullptr || outSize <= 0) {
        printf("CompareResult: invalid output buffer\n");
        return false;
    }
    const size_t expectedBytes = static_cast<size_t>(outSize);

    void* goldenData = nullptr;
    CHECK_ACL(aclrtMallocHost(&goldenData, expectedBytes));
    if (goldenData == nullptr) {
        return false;  // CHECK_ACL only logs, so guard the allocation explicitly
    }
    // ReadFile overwrites goldenSize with the number of bytes actually read.
    size_t goldenSize = expectedBytes;
    bool ret = ReadFile("../output/golden.bin", goldenSize, goldenData, expectedBytes);
    if (!ret) {
        CHECK_ACL(aclrtFreeHost(goldenData));
        return false;
    }
    printf("ReadFile golden.bin success!\n");
    if (goldenSize != expectedBytes) {
        printf("CompareResult golden.bin size %zu does not match output size %zu\n", goldenSize, expectedBytes);
        CHECK_ACL(aclrtFreeHost(goldenData));
        return false;
    }

    constexpr float EPS = 1e-6f;
    const float* actual = static_cast<const float*>(outputData);
    const float* golden = static_cast<const float*>(goldenData);
    const size_t count = expectedBytes / sizeof(float);
    int64_t wrongNum = 0;
    for (size_t i = 0; i < count; i++) {
        if (!IsWithinTolerance(actual[i], golden[i], EPS)) {
            printf("CompareResult golden.bin failed output is %lf, golden is %lf\n",
                   static_cast<double>(actual[i]), static_cast<double>(golden[i]));
            wrongNum++;
        }
    }
    CHECK_ACL(aclrtFreeHost(goldenData));

    if (wrongNum != 0) {
        return false;
    }
    printf("CompareResult golden.bin success\n");
    return true;
}

int32_t main(void)
{
    int64_t input[] = { 1024 };
    int64_t output[] = { 1024 };
    const struct tensorInfo tensorDesc[] = {
        { input, 1, ACL_FLOAT, ACL_FORMAT_ND },
        { output, 1, ACL_FLOAT, ACL_FORMAT_ND },
    };
    aclrtStream stream = CreateStream(0);
    if (stream == nullptr) {
        return -1;
    }
    const int32_t tensorCount = sizeof(tensorDesc) / sizeof(struct tensorInfo);
    aclTensor* tensors[tensorCount];
    void* devMem[tensorCount];
    for (auto i = 0; i < tensorCount; i++) {
        void* data;
        const struct tensorInfo* info = &(tensorDesc[i]);
        int64_t size = GetDataSize(info);
        if (size == 0) {
            tensors[i] = nullptr;
            devMem[i] = nullptr;
            continue;
        }
        CHECK_ACL(aclrtMalloc(&data, size, ACL_MEM_MALLOC_HUGE_FIRST));
        // Allocate host memory and copy input data to host memory
        if (i == 0) {
            size_t inputSize = size;
            void* dataHost;
            CHECK_ACL(aclrtMallocHost((void**)(&dataHost), inputSize));
            ReadFile("../input/input.bin", inputSize, dataHost, inputSize);
            CHECK_ACL(aclrtMemcpy(data, size, dataHost, size, ACL_MEMCPY_HOST_TO_DEVICE));
            CHECK_ACL(aclrtFreeHost(dataHost));
        }
        devMem[i] = data;
        tensors[i] = aclCreateTensor(info->dims, info->dimCnt, info->dtype, nullptr, 0, info->fmt, info->dims,
                                     info->dimCnt, data);
    }

    size_t workspaceSize = 0;
    aclOpExecutor* handle;
    // Tensor order is input, output
    int32_t ret = aclnnFasterGeluCustomGetWorkspaceSize(tensors[0], tensors[1], &workspaceSize, &handle);
    if (ret != ACL_SUCCESS) {
        printf("aclnnFasterGeluCustomGetWorkspaceSize failed. error code is %d\n", ret);
        DestroyTensor(tensors, devMem, tensorCount);
        DestroyStream(stream, 0);
        return ret;
    }
    printf("aclnnFasterGeluCustomGetWorkspaceSize ret %d workspace size %lu\n", ret, workspaceSize);
    void* workspace = nullptr;
    if (workspaceSize != 0) {
        CHECK_ACL(aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST));
    }
    ret = aclnnFasterGeluCustom(workspace, workspaceSize, handle, stream);
    printf("aclnnFasterGeluCustom ret %d\n", ret);
    CHECK_ACL(aclrtSynchronizeStream(stream));

    uint8_t* outputHost;
    int64_t outputHostSize = GetDataSize(&(tensorDesc[1]));

    CHECK_ACL(aclrtMallocHost((void**)(&outputHost), outputHostSize));
    CHECK_ACL(aclrtMemcpy(outputHost, outputHostSize, devMem[1], outputHostSize, ACL_MEMCPY_DEVICE_TO_HOST));
    WriteFile("../output/output.bin", outputHost, outputHostSize);

    bool goldenResult = CompareResult(outputHost, outputHostSize);
    if (goldenResult) {
        printf("test pass!\n");
    } else {
        printf("test failed!\n");
    }
    if (workspaceSize != 0) {
        CHECK_ACL(aclrtFree(workspace));
    }
    CHECK_ACL(aclrtFreeHost(outputHost));

    DestroyTensor(tensors, devMem, tensorCount);
    DestroyStream(stream, 0);
    return 0;
}
